# -*- coding: utf-8 -*-
"""
Created on Fri Sep 17 10:29:29 2021
A DQN algorithm written with TF2 for state preparation: driving an arbitrary single-qubit state to the |0> state.
@author: Waikikilick
"""

import  os
os.environ['TF_CPP_MIN_LOG_LEVEL']='2' #设置 tf 的报警等级
import tensorflow as tf
import numpy as np
from tensorflow import keras
from tensorflow.keras import layers,optimizers,losses,initializers,Sequential,metrics,models
import copy 
from collections import deque
import random
from scipy.linalg import expm
from time import *

# Seed TensorFlow, NumPy and the stdlib `random` module (the latter drives the
# random.shuffle call in env.psi_set, so the data split is repeatable).
tf.random.set_seed(1)
np.random.seed(1)
random.seed (1)

class env(object):
    """Single-qubit control environment whose target is the |0> state.

    A state is the real vector ``[Re(a), Re(b), Im(a), Im(b)]`` describing the
    qubit ``a|0> + b|1>``.  One step applies ``expm(-1j * H * dt)`` with
    ``H = J * sigma_z + h * sigma_x``; the numeric value of the action is used
    directly as ``J`` and ``h`` is 1 in the noise-free case.

    Attributes:
        action_space: list of allowed actions.
        n_actions: number of allowed actions.
        n_features: length of the state vector (4).
        target_psi: 2x1 complex matrix of the target state |0>.
        s_x, s_z: Pauli X / Z as 2x2 complex matrices.
        dt: duration of one control step.
        training_set, validation_set, testing_set: lists of 2x1 complex
            matrices produced by :meth:`psi_set`.
        noise_normal: 41 fixed samples recorded from a standard normal
            distribution (see the comment next to the array).
        noise_constant: 41 ones, used as the profile of static noise.
        noise: per-step noise values read by ``step_noise_J`` and
            ``step_noise_h``; zeros until a caller overwrites it.
    """

    def __init__(self,
                 action_space=(0, 1),  # default only; callers normally pass their own
                 dt=0.1,
                 ):
        # Copy into a fresh list so instances never share a mutable default.
        self.action_space = list(action_space)
        self.n_actions = len(self.action_space)
        self.n_features = 4  # length of the state description
        # np.asmatrix replaces np.mat, which was removed in NumPy 2.0.
        self.target_psi = np.asmatrix([[1], [0]], dtype=complex)  # target state |0>
        self.s_x = np.asmatrix([[0, 1], [1, 0]], dtype=complex)
        self.s_z = np.asmatrix([[1, 0], [0, -1]], dtype=complex)
        self.dt = dt
        self.training_set, self.validation_set, self.testing_set = self.psi_set()

        # Fixed samples with mean 0 and standard deviation 1, originally
        # generated by: np.random.seed(1); np.random.normal(loc=0.0, scale=1.0, size=41)
        self.noise_normal = np.array([
             1.62434536, -0.61175641, -0.52817175, -1.07296862,  0.86540763,
            -2.3015387 ,  1.74481176, -0.7612069 ,  0.3190391 , -0.24937038,
             1.46210794, -2.06014071, -0.3224172 , -0.38405435,  1.13376944,
            -1.09989127, -0.17242821, -0.87785842,  0.04221375,  0.58281521,
            -1.10061918,  1.14472371,  0.90159072,  0.50249434,  0.90085595,
            -0.68372786, -0.12289023, -0.93576943, -0.26788808,  0.53035547,
            -0.69166075, -0.39675353, -0.6871727 , -0.84520564, -0.67124613,
            -0.0126646 , -1.11731035,  0.2344157 ,  1.65980218,  0.74204416,
            -0.19183555])

        self.noise_constant = np.ones(41)

        # Noise-free default so the step_noise_* methods work before a caller
        # assigns a real noise profile.
        self.noise = np.zeros(41)

    def psi_set(self):
        """Build the initial states on the Bloch sphere and split them.

        Returns:
            ``(training_set, validation_set, testing_set)`` holding 32, 32 and
            64 states respectively, each a 2x1 complex matrix.  The order comes
            from ``random.shuffle`` and therefore depends on the ``random``
            module's state.
        """
        theta_num = 6    # number of polar angles, excluding 0 and pi
        varphi_num = 21  # number of azimuthal angles around one circle
        # Total number of points: theta_num * varphi_num + 2 poles = 6 * 21 + 2 = 128

        theta = np.linspace(0, np.pi, theta_num + 1, endpoint=False)[1:]  # drop theta = 0
        varphi = np.linspace(0, np.pi * 2, varphi_num, endpoint=False)

        psi_set = [
            np.asmatrix([[np.cos(t / 2)],
                         [np.sin(t / 2) * (np.cos(p) + np.sin(p) * (0 + 1j))]])
            for t in theta
            for p in varphi
        ]
        # The two poles of the Bloch sphere.
        psi_set.append(np.asmatrix([[1], [0]], dtype=complex))
        psi_set.append(np.asmatrix([[0], [1]], dtype=complex))
        random.shuffle(psi_set)  # shuffle the point set

        training_set = psi_set[0:32]
        validation_set = psi_set[32:64]
        testing_set = psi_set[64:128]

        return training_set, validation_set, testing_set

    def reset(self, init_psi):
        """Convert a 2x1 complex state into the real 4-vector used as network input.

        Args:
            init_psi: 2-D indexable state (``init_psi[0, 0]`` and
                ``init_psi[1, 0]`` are read).

        Returns:
            ``np.array([Re(a), Re(b), Im(a), Im(b)])``.
        """
        init_state = np.array([init_psi[0, 0].real, init_psi[1, 0].real,
                               init_psi[0, 0].imag, init_psi[1, 0].imag])
        return init_state

    def _evolve(self, state, z_coeff, x_coeff):
        """Apply one time step of ``H = z_coeff * s_z + x_coeff * s_x``.

        Returns:
            ``(next_state, fid)`` where ``next_state`` is the real 4-vector of
            the evolved state and ``fid`` is its fidelity with ``target_psi``.
        """
        state = np.asarray(state, dtype=float)
        half = state.size // 2
        # Rebuild the complex amplitudes: first half real parts, second half imaginary.
        psi = state[:half] + 1j * state[half:]

        hamiltonian = z_coeff * np.asarray(self.s_z) + x_coeff * np.asarray(self.s_x)
        propagator = expm(-1j * hamiltonian * self.dt)
        psi = propagator @ psi  # next state

        target = np.asarray(self.target_psi).ravel()
        fid = float(np.abs(np.vdot(psi, target)) ** 2)  # |<psi|target>|^2

        # Back to the real representation: the network input cannot be complex.
        next_state = np.concatenate((psi.real, psi.imag))
        return next_state, fid

    def step(self, state, action, nstep):
        """Advance the noise-free dynamics by one step.

        Args:
            state: real 4-vector ``[Re(a), Re(b), Im(a), Im(b)]``.
            action: value used (after ``float()``) as the sigma_z coefficient.
            nstep: index of the current step, used for the step-count cutoff.

        Returns:
            ``(state, rwd, done, fid)``: the next state, the reward (equal to
            the fidelity), the termination flag and the fidelity.
        """
        state, fid = self._evolve(state, float(action), 1)

        err = 1e-3  # same value as the original literal 10e-4
        rwd = fid
        # Stop when the infidelity is below err or the step budget 2*pi/dt is used up.
        done = bool(((1 - fid) < err) or nstep >= 2 * np.pi / self.dt)

        return state, rwd, done, fid

    def step_noise_J(self, state, action, nstep):
        """Advance one step with ``self.noise[nstep]`` added to the sigma_z coefficient.

        ``nstep`` must be a valid index into ``self.noise``.

        Returns:
            ``(state, fid)``: the next state and its fidelity with the target.
        """
        return self._evolve(state, float(action) + self.noise[nstep], 1)

    def step_noise_h(self, state, action, nstep):
        """Advance one step with ``self.noise[nstep]`` added to the sigma_x coefficient.

        ``nstep`` must be a valid index into ``self.noise``.

        Returns:
            ``(state, fid)``: the next state and its fidelity with the target.
        """
        return self._evolve(state, float(action), 1 + self.noise[nstep])
def relu(vector):
    """Clamp every entry of ``vector`` at zero, in place, and return it.

    Each ``vector[i]`` is overwritten with ``np.maximum(0, vector[i])``, so the
    caller's object is modified and the same object is returned.
    """
    for position, entry in enumerate(vector):
        vector[position] = np.maximum(0, entry)
    return vector

def agent_matrix(x):
    """Return the greedy action for state ``x`` from the loaded network weights.

    The forward pass is done with plain matrix algebra on the module-level
    ``net`` list: entries 0..5 are read as (kernel, bias) pairs of three
    layers, each followed by ``relu``.

    Args:
        x: state vector; it is viewed as a row matrix before the first layer.

    Returns:
        Index of the largest output, as returned by ``np.argmax``.
    """
    # np.asmatrix replaces np.mat, which was removed in NumPy 2.0.
    out = np.asmatrix(x)
    for layer in range(3):
        kernel = net[2 * layer].numpy()
        bias = net[2 * layer + 1].numpy()
        # NOTE(review): ReLU is also applied to the last layer, as in the
        # original code; if every output is negative they all become 0 and
        # argmax picks action 0 - confirm this matches the trained model.
        out = relu(out * kernel + bias)
    action = np.argmax(out)
    return action

def testing():
    """Run the agent on every state of the test set without noise.

    For each initial state in ``env.testing_set`` an episode is played with
    ``agent_matrix`` choosing the actions until ``env.step`` reports ``done``.

    Returns:
        A list with the highest fidelity reached in each episode.
    """
    print('\n测试中, 请稍等...')

    best_fidelities = []

    for start_psi in env.testing_set:
        current = env.reset(start_psi)
        best = 0
        step_count = 0
        finished = False

        # At least one step is always taken, exactly like a do-while loop.
        while not finished:
            chosen = agent_matrix(current)
            current, _, finished, fid = env.step(current, chosen, step_count)
            step_count += 1
            best = max(best, fid)

        best_fidelities.append(best)

    return best_fidelities

def _replay_under_noise(noise, noisy_step):
    """Shared body of the four ``testing_noise_*`` functions.

    For every state in ``env.testing_set``:

    1. play a noise-free episode with ``agent_matrix`` and ``env.step``,
       recording the actions and fidelities;
    2. keep the actions up to (and including) the first step that reached the
       highest fidelity;
    3. replay those actions from the same initial state with ``noisy_step``
       and record the fidelity after the last replayed step.

    Args:
        noise: per-step noise values; stored on ``env.noise`` so the noisy
            step methods can read them.
        noisy_step: callable ``(state, action, nstep) -> (state, fid)``.

    Returns:
        List with one final noisy fidelity per test state.
    """
    env.noise = noise
    fid_noise_list = []

    # env.testing_set is used (as in testing()) rather than a separate global,
    # so these functions do not depend on the script's __main__ section.
    for test_init_psi in env.testing_set:

        # --- noise-free rollout to obtain the control sequence ---
        observation = env.reset(test_init_psi)
        action_list = []  # actions taken in this episode
        fid_list = []     # fidelity after each step
        nstep = 0
        done = False
        while not done:
            action = agent_matrix(observation)
            observation, _, done, fid = env.step(observation, action, nstep)
            action_list.append(action)
            fid_list.append(fid)
            nstep += 1

        # Truncate the sequence at the step with the highest fidelity.
        max_index = fid_list.index(max(fid_list))

        # --- replay the truncated sequence with noise ---
        observation = env.reset(test_init_psi)
        for nstep, action in enumerate(action_list[:max_index + 1]):
            observation, fid = noisy_step(observation, action, nstep)

        # The fidelity after the last replayed step is the episode's result.
        fid_noise_list.append(fid)

    return fid_noise_list


def testing_noise_J_dynamic(noise_a):
    """Test with dynamic noise on J (the sigma_z coefficient).

    The noise is ``noise_a * env.noise_normal``, i.e. a different value at
    each step.  Returns the list of final fidelities, one per test state.
    """
    print('\n测试 J 动态噪声中, 请稍等...')
    print('噪声振幅为：', noise_a)

    return _replay_under_noise(noise_a * env.noise_normal, env.step_noise_J)


def testing_noise_J_static(noise_a):
    """Test with static noise on J (the sigma_z coefficient).

    The noise is ``noise_a * env.noise_constant``, i.e. the same offset at
    every step.  Returns the list of final fidelities, one per test state.
    """
    print('\n测试 J 静态噪声中, 请稍等...')
    print('噪声振幅为：', noise_a)

    return _replay_under_noise(noise_a * env.noise_constant, env.step_noise_J)


def testing_noise_h_dynamic(noise_a):
    """Test with dynamic noise on h (the sigma_x coefficient).

    The noise is ``noise_a * env.noise_normal``, i.e. a different value at
    each step.  Returns the list of final fidelities, one per test state.
    """
    print('\n测试 h 动态噪声中, 请稍等...')
    print('噪声振幅为：', noise_a)

    return _replay_under_noise(noise_a * env.noise_normal, env.step_noise_h)


def testing_noise_h_static(noise_a):
    """Test with static noise on h (the sigma_x coefficient).

    The noise is ``noise_a * env.noise_constant``, i.e. the same offset at
    every step.  Returns the list of final fidelities, one per test state.
    """
    print('\n测试 h 静态噪声中, 请稍等...')
    print('噪声振幅为：', noise_a)

    return _replay_under_noise(noise_a * env.noise_constant, env.step_noise_h)
                 
    
# Script entry point: load the trained network, build the environment and
# report the test-set fidelity without noise and for a sweep of noise amplitudes.
# The names net, env and testing_set assigned here are module globals that the
# functions above read.
if __name__ == "__main__":
    
    dt = np.pi/10
    
    network = keras.models.load_model('dqn_tf2_single_qubit_to_0_saved_model') # load the trained network
    net = network.trainable_variables   
    
    # NOTE: this rebinds the name `env` from the class to an instance.
    env = env(action_space = list(range(4)),   # allowed actions 0 .. 4-1, i.e. 4 of them
               dt = dt)        
    testing_set = env.testing_set   
    
    

    # Noise-free test
    testing_fid_list = testing()
    print('测试集平均保真度为：', np.mean(testing_fid_list))
    
    # Test how noise affects the fidelity
    fids_noise_list = []
    
    # Static-noise sweep (disabled):
    # for noise_a in [-0.1, -0.09, -0.08, -0.07, -0.06, -0.05, -0.04, -0.03, -0.02, -0.01, 0, 0.01, 0.02, 0.03, 0.04, 0.05, 0.06, 0.07, 0.08, 0.09, 0.1]:
    #     test_noise_mean = np.mean(testing_noise_J_static(noise_a))
    #     # test_noise_mean = np.mean(testing_noise_h_static(noise_a))
    #     print('噪声下的平均保真度为：',test_noise_mean)
    #     fids_noise_list.append(test_noise_mean)
        
    # Dynamic-noise sweep; switch the commented line to test noise on J instead of h.
    for noise_a in [0, 0.02, 0.04, 0.06, 0.08, 0.10, 0.12, 0.14, 0.16, 0.18, 0.2]:
        # test_noise_mean = np.mean(testing_noise_J_dynamic(noise_a))
        test_noise_mean = np.mean(testing_noise_h_dynamic(noise_a))
        print('噪声下的平均保真度为：',test_noise_mean)
        fids_noise_list.append(test_noise_mean) 
    
    print('测试集在噪声下平均保真度为', fids_noise_list)
    
   # J_static: [0.9720746964206154, 0.9759220075795664, 0.9795182242071243, 0.9828381385443261, 0.9858562260014769, 0.9885468831169202, 0.990884677782273, 0.9928446091000438, 0.9944023740415062, 0.9955346379073958, 0.9962193054630456, 0.9964357895266029, 0.996165273735105, 0.9953909662003306, 0.994098340795674, 0.9922753628866732, 0.9899126964316988, 0.9870038895348667, 0.9835455357288267, 0.9795374084996027, 0.9749825668360144]
   #           [0.9720747  0.97592201 0.97951822 0.98283814 0.98585623 0.98854688 0.99088468 0.99284461 0.99440237 0.99553464 0.99621931 0.99643579 0.99616527 0.99539097 0.99409834 0.99227536 0.9899127  0.98700389 0.98354554 0.97953741 0.97498257]

   # h_static: [0.9808062532630111, 0.9834450147532554, 0.9858557645568027, 0.9880323590437721, 0.989968745007979, 0.9916589758312547, 0.993097228314088, 0.9942778201449627, 0.995195227976117, 0.9958441060684189, 0.9962193054630456, 0.9963158936326311, 0.996129174559304, 0.9956547091820361, 0.9948883361504977, 0.9938261928176746, 0.9924647363984773, 0.9908007652166777, 0.9888314399580014, 0.9865543048424735, 0.9839673086250874]
   #           [0.98080625 0.98344501 0.98585576 0.98803236 0.98996875 0.99165898 0.99309723 0.99427782 0.99519523 0.99584411 0.99621931 0.99631589 0.99612917 0.99565471 0.99488834 0.99382619 0.99246474 0.99080077 0.98883144 0.9865543  0.98396731]

   # J_dynamic: [0.9962193054630456, 0.9959601369973878, 0.9955143292541403, 0.994881000208637, 0.9940593786420165, 0.9930488079489566, 0.9918487502154562, 0.9904587905442214, 0.9888786416011632, 0.9871081483523769, 0.9851472929572506]
                # [0.99621931 0.99596014 0.99551433 0.994881   0.99405938 0.99304881 0.99184875 0.99045879 0.98887864 0.98710815 0.98514729]

    # h_dynamic: [0.9962193054630456, 0.995482175288021, 0.9941394447629497, 0.992190390597659, 0.9896367134462231, 0.9864825504516801, 0.9827344719569449, 0.9784014622339485, 0.9734948842246589, 0.9680284284314346, 0.9620180462384569]
                # [0.99621931 0.99548218 0.99413944 0.99219039 0.98963671 0.98648255 0.98273447 0.97840146 0.97349488 0.96802843 0.96201805]